{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9c10082182ad6fbd",
   "metadata": {},
   "source": "# Estimating the uncertainties in the exoplanet masses"
  },
  {
   "cell_type": "markdown",
   "id": "5f1ec24e9c03c240",
   "metadata": {},
   "source": "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/scikit-learn-contrib/MAPIE/blob/master/notebooks/regression/exoplanets.ipynb)\n"
  },
  {
   "cell_type": "markdown",
   "id": "88604cdcf55449a9",
   "metadata": {},
   "source": "In this notebook, we quantify the uncertainty in exoplanet masses predicted by several machine learning models from the properties of each planet and its host star. To this end, we use an exoplanet dataset downloaded from the [NASA Exoplanet Archive](https://exoplanetarchive.ipac.caltech.edu/) and estimate prediction intervals with the methods implemented in MAPIE."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee3c90cd0a9cabd",
   "metadata": {},
   "outputs": [],
   "source": [
    "install_mapie = False\n",
    "if install_mapie:\n",
    "    !pip install mapie"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f2cc6cd5bc57079",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing_extensions import TypedDict\n",
    "from typing import Union\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import (\n",
    "    OneHotEncoder,\n",
    "    OrdinalEncoder,\n",
    "    PolynomialFeatures,\n",
    "    RobustScaler\n",
    ")\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "from mapie.metrics.regression import regression_coverage_score\n",
    "from mapie.regression import CrossConformalRegressor, JackknifeAfterBootstrapRegressor\n",
    "from mapie.subsample import Subsample\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "347c5cb4e9cadaf2",
   "metadata": {},
   "source": "## 1. Data loading"
  },
  {
   "cell_type": "markdown",
   "id": "325d138e1af4a9cb",
   "metadata": {},
   "source": "Let's start by loading the `exoplanets` dataset and looking at the main information."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0c27df729db836",
   "metadata": {},
   "outputs": [],
   "source": [
    "url_file = \"https://raw.githubusercontent.com/scikit-learn-contrib/MAPIE/master/notebooks/regression/exoplanets_mass.csv\"\n",
    "exo_df = pd.read_csv(url_file, index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a61118b0fb649394",
   "metadata": {},
   "outputs": [],
   "source": [
    "exo_df.info()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b951cd66a4cf7aeb",
   "metadata": {},
   "source": "The dataset contains 21 features giving complementary information about the properties of the discovered planet, the star around which the planet revolves, together with the type of discovery method. 7 features are categorical, and 14 are continuous."
  },
  {
   "cell_type": "markdown",
   "id": "2ee380ae23a88937",
   "metadata": {},
   "source": "Some properties show high variance among exoplanets and stars due to the astronomical nature of such systems. We therefore log-transform the following features to bring their distributions closer to a normal distribution."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4d9a30207568e2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Treat zero stellar masses as missing so that np.log does not return -inf.\n",
    "exo_df[\"Stellar_Mass_[Solar_mass]\"] = exo_df[\"Stellar_Mass_[Solar_mass]\"].replace(0, np.nan)\n",
    "vars2log = [\n",
    "    \"Planet_Orbital_Period_[day]\",\n",
    "    \"Planet_Orbital_SemiMajorAxis_[day]\",\n",
    "    \"Planet_Radius_[Earth_radius]\",\n",
    "    \"Planet_Mass_[Earth_mass]\",\n",
    "    \"Stellar_Radius_[Solar_radius]\",\n",
    "    \"Stellar_Mass_[Solar_mass]\",\n",
    "    \"Stellar_Effective_Temperature_[K]\"\n",
    "]\n",
    "for var in vars2log:\n",
    "    exo_df[var+\"_log\"] = np.log(exo_df[var])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f161bd4435ee562",
   "metadata": {},
   "outputs": [],
   "source": [
    "vars2keep = [col for col in exo_df.columns if col not in vars2log]  # keeps a deterministic column order\n",
    "exo_df = exo_df[vars2keep]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90ca4d6c1991072b",
   "metadata": {},
   "outputs": [],
   "source": [
    "exo_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d7a4e6f77ffc77e",
   "metadata": {},
   "source": "Throughout this tutorial, the target variable will be `Planet_Mass_[Earth_mass]_log`."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65ea679dc51c5fc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "target = \"Planet_Mass_[Earth_mass]_log\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eac7e90b2b0d39cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_cols = list(exo_df.columns[exo_df.dtypes == \"float64\"])\n",
    "cat_cols = list(exo_df.columns[exo_df.dtypes != \"float64\"])\n",
    "exo_df[cat_cols] = exo_df[cat_cols].astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc2ac78c11d2d6d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "planet_cols = [col for col in num_cols if \"Planet_\" in col]\n",
    "star_cols = [col for col in num_cols if \"Stellar_\" in col]\n",
    "system_cols = [col for col in num_cols if \"System_\" in col]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39e1f6828e7015f9",
   "metadata": {},
   "source": "## 2. Data visualization"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7239d6da41af3d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.pairplot(exo_df[planet_cols])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9a11057c02d3ca5",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.pairplot(exo_df[star_cols])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2459fa71a4b09ad3",
   "metadata": {},
   "source": "## 3. Data preprocessing"
  },
  {
   "cell_type": "markdown",
   "id": "9a96a34e520b413f",
   "metadata": {},
   "source": "In this section, we perform a simple preprocessing of the dataset in order to impute the missing values and encode the categorical features."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a11454a231d2cdd",
   "metadata": {},
   "outputs": [],
   "source": [
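    "endos = [col for col in exo_df.columns if col != target]  # deterministic column order\n",
    "X = exo_df[endos].copy()  # copy so that later column assignments do not act on a view\n",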
    "y = exo_df[target]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "245b4d24119661c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_cols = list(X.columns[X.dtypes == \"float64\"])\n",
    "cat_cols = list(X.columns[X.dtypes != \"float64\"])\n",
    "X[cat_cols] = X[cat_cols].astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c11c21129a6455",
   "metadata": {},
   "outputs": [],
   "source": [
    "imputer_num = SimpleImputer(strategy=\"mean\")\n",
    "scaler_num = RobustScaler()\n",
    "imputer_cat = SimpleImputer(strategy=\"constant\", fill_value=-1)\n",
    "encoder_cat = OneHotEncoder(\n",
    "    categories=\"auto\",\n",
    "    drop=None,\n",
    "    handle_unknown=\"ignore\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7427c915309948d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "numerical_transformer = Pipeline(\n",
    "    steps=[(\"imputer\", imputer_num), (\"scaler\", scaler_num)]\n",
    ")\n",
    "categorical_transformer = Pipeline(\n",
    "    steps=[(\"ordinal\", OrdinalEncoder()), (\"imputer\", imputer_cat), (\"encoder\", encoder_cat)]\n",
    ")\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        (\"numerical\", numerical_transformer, num_cols),\n",
    "        (\"categorical\", categorical_transformer, cat_cols)\n",
    "    ],\n",
    "    remainder=\"drop\",\n",
    "    sparse_threshold=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b613bbd4fee1dab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.2, random_state=42, shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99815400fe893a22",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = preprocessor.fit_transform(X_train)\n",
    "X_test = preprocessor.transform(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc9be8d89d9b20f3",
   "metadata": {},
   "source": "## 4. First estimation of the uncertainties with MAPIE"
  },
  {
   "cell_type": "markdown",
   "id": "34f80e4b0ec3e262",
   "metadata": {},
   "source": "### Uncertainty estimation"
  },
  {
   "cell_type": "markdown",
   "id": "12cf6b947996819c",
   "metadata": {},
   "source": "Here, we build our first prediction intervals with MAPIE. To this end, we adopt the CV+ strategy with 5 folds, passing `method=\"plus\"` and `cv=5` to `CrossConformalRegressor`."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd43ae50bb955af5",
   "metadata": {},
   "outputs": [],
   "source": [
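    "def get_regressor(name):\n",
    "    \"\"\"Return an unfitted regressor for the given model name.\"\"\"\n",
    "    if name == \"linear\":\n",
    "        mdl = LinearRegression()\n",
    "    elif name == \"polynomial\":\n",
    "        degree_polyn = 2\n",
    "        mdl = Pipeline(\n",
    "            [\n",
    "                (\"poly\", PolynomialFeatures(degree=degree_polyn)),\n",
    "                (\"linear\", LinearRegression())\n",
    "            ]\n",
    "        )\n",
    "    elif name == \"random_forest\":\n",
    "        mdl = RandomForestRegressor()\n",
    "    else:\n",
    "        # Fail early instead of hitting an unbound `mdl` at the return statement.\n",
    "        raise ValueError(f\"Unknown regressor name: {name!r}\")\n",
    "    return mdl"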
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f1c3454a72dcae1",
   "metadata": {},
   "outputs": [],
   "source": [
    "mdl = get_regressor(\"random_forest\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4dd6bf5dd404815",
   "metadata": {},
   "outputs": [],
   "source": [
    "confidence_level = np.arange(0.05, 1, 0.05)\n",
    "mapie = CrossConformalRegressor(\n",
    "    estimator=mdl,\n",
    "    confidence_level=confidence_level,\n",
    "    method=\"plus\",\n",
    "    cv=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "517009cca18bf9a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "mapie.fit_conformalize(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4f12b07eb9fcd1c",
   "metadata": {},
   "source": "We build prediction intervals for a range of confidence levels between 0.05 and 0.95."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6827c93198e211b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train_pred, y_train_pis = mapie.predict_interval(X_train)\n",
    "y_test_pred, y_test_pis = mapie.predict_interval(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f96144ade7c69a2d",
   "metadata": {},
   "source": "### Visualization"
  },
  {
   "cell_type": "markdown",
   "id": "d2d51d8d2556b946",
   "metadata": {},
   "source": "The following function visualizes the prediction intervals estimated by MAPIE as error bars, for a given method and confidence level."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a86c08738a8c9c68",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_predictionintervals(\n",
    "    y_train,\n",
    "    y_train_pred,\n",
    "    y_train_pred_low,\n",
    "    y_train_pred_high,\n",
    "    y_test,\n",
    "    y_test_pred,\n",
    "    y_test_pred_low,\n",
    "    y_test_pred_high,\n",
    "    suptitle: str,\n",
    ") -> None:\n",
    "    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6))\n",
    "    \n",
    "    ax1.errorbar(\n",
    "        x=y_train,\n",
    "        y=y_train_pred,\n",
    "        yerr=(y_train_pred - y_train_pred_low, y_train_pred_high - y_train_pred),\n",
    "        alpha=0.8,\n",
    "        label=\"train\",\n",
    "        fmt=\".\",\n",
    "    )\n",
    "    ax1.errorbar(\n",
    "        x=y_test,\n",
    "        y=y_test_pred,\n",
    "        yerr=(y_test_pred - y_test_pred_low, y_test_pred_high - y_test_pred),\n",
    "        alpha=0.8,\n",
    "        label=\"test\",\n",
    "        fmt=\".\",\n",
    "    )\n",
    "    ax1.plot(\n",
    "        [y_train.min(), y_train.max()],\n",
    "        [y_train.min(), y_train.max()],\n",
    "        color=\"gray\",\n",
    "        alpha=0.5,\n",
    "    )\n",
    "    ax1.set_xlabel(\"True values\", fontsize=12)\n",
    "    ax1.set_ylabel(\"Predicted values\", fontsize=12)\n",
    "    ax1.legend()\n",
    "    \n",
    "    ax2.scatter(\n",
    "        x=y_train, y=y_train_pred_high - y_train_pred_low, alpha=0.8, label=\"train\", marker=\".\"\n",
    "    )\n",
    "    ax2.scatter(x=y_test, y=y_test_pred_high - y_test_pred_low, alpha=0.8, label=\"test\", marker=\".\")\n",
    "    ax2.set_xlabel(\"True values\", fontsize=12)\n",
    "    ax2.set_ylabel(\"Interval width\", fontsize=12)\n",
    "    ax2.set_xscale(\"linear\")\n",
    "    ax2.set_ylim([0, np.max(y_test_pred_high - y_test_pred_low)*1.1])\n",
    "    ax2.legend()\n",
    "    std_all = np.concatenate([\n",
    "        y_train_pred_high - y_train_pred_low, y_test_pred_high - y_test_pred_low\n",
    "    ])\n",
    "    type_all = np.array([\"train\"] * len(y_train) + [\"test\"] * len(y_test))\n",
    "    x_all = np.arange(len(std_all))\n",
    "    order_all = np.argsort(std_all)\n",
    "    std_order = std_all[order_all]\n",
    "    type_order = type_all[order_all]\n",
    "    ax3.scatter(\n",
    "        x=x_all[type_order == \"train\"],\n",
    "        y=std_order[type_order == \"train\"],\n",
    "        alpha=0.8,\n",
    "        label=\"train\",\n",
    "        marker=\".\",\n",
    "    )\n",
    "    ax3.scatter(\n",
    "        x=x_all[type_order == \"test\"],\n",
    "        y=std_order[type_order == \"test\"],\n",
    "        alpha=0.8,\n",
    "        label=\"test\",\n",
    "        marker=\".\",\n",
    "    )\n",
    "    ax3.set_xlabel(\"Order\", fontsize=12)\n",
    "    ax3.set_ylabel(\"Interval width\", fontsize=12)\n",
    "    ax3.legend()\n",
    "    ax1.set_title(\"True vs predicted values\")\n",
    "    ax2.set_title(\"Prediction interval width vs true values\")\n",
    "    ax3.set_title(\"Ordered prediction interval width\")\n",
    "    plt.suptitle(suptitle, size=20)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "718413639cfc4cce",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha_plot = int(np.argmin(np.abs(confidence_level - 0.9)))  # index of the 0.9 confidence level\n",
    "plot_predictionintervals(\n",
    "    y_train,\n",
    "    y_train_pred,\n",
    "    y_train_pis[:, 0, alpha_plot],\n",
    "    y_train_pis[:, 1, alpha_plot],\n",
    "    y_test,\n",
    "    y_test_pred,\n",
    "    y_test_pis[:, 0, alpha_plot],\n",
    "    y_test_pis[:, 1, alpha_plot],\n",
    "    \"Prediction intervals for confidence_level=0.9\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cc1c20d579d0050",
   "metadata": {},
   "source": "## 5. Comparison of the uncertainty quantification methods"
  },
  {
   "cell_type": "markdown",
   "id": "2c7a893bc94dc698",
   "metadata": {},
   "source": "In this last section, we compare the calibration of several uncertainty-quantification methods provided by MAPIE, using a random forest as the base model. To this end, we build so-called \"calibration plots\", which compare the effective marginal coverage obtained on the test set with the target coverage $1-\\alpha$."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73af775d3e262e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "Params = TypedDict(\"Params\", {\"method\": str, \"cv\": Union[int, Subsample]})\n",
    "STRATEGIES = {\n",
    "    \"cv\": {\n",
    "        \"class\": CrossConformalRegressor,\n",
    "        \"init_params\": dict(method=\"base\", cv=5),\n",
    "    },\n",
    "    \"cv_plus\": {\n",
    "        \"class\": CrossConformalRegressor,\n",
    "        \"init_params\": dict(method=\"plus\", cv=5),\n",
    "    },\n",
    "    \"cv_minmax\": {\n",
    "        \"class\": CrossConformalRegressor,\n",
    "        \"init_params\": dict(method=\"minmax\", cv=5),\n",
    "    },\n",
    "    \"jackknife_plus_ab\": {\n",
    "        \"class\": JackknifeAfterBootstrapRegressor,\n",
    "        \"init_params\": dict(method=\"plus\", resampling=20),\n",
    "    },\n",
    "}\n",
    "mdl = get_regressor(\"random_forest\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "353f8514b22556ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred, y_pis, scores = {}, {}, {}\n",
    "RANDOM_STATE = 1\n",
    "for strategy_name, strategy_params in STRATEGIES.items():\n",
    "    init_params = strategy_params[\"init_params\"]\n",
    "    class_ = strategy_params[\"class\"]\n",
    "    mapie = class_(\n",
    "        mdl, confidence_level=confidence_level,\n",
    "        random_state=RANDOM_STATE, **init_params\n",
    "    )\n",
    "    mapie.fit_conformalize(X_train, y_train)\n",
    "    y_pred[strategy_name], y_pis[strategy_name] = mapie.predict_interval(X_test)\n",
    "    scores[strategy_name] = regression_coverage_score(y_test, y_pis[strategy_name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "937ea2ccad542547",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7, 6))\n",
    "plt.xlabel(\"Target coverage (1 - alpha)\")\n",
    "plt.ylabel(\"Effective coverage\")\n",
    "for strategy, params in STRATEGIES.items():\n",
    "    plt.plot(confidence_level, scores[strategy], label=strategy)\n",
    "plt.plot([0, 1], [0, 1], ls=\"--\", color=\"k\")\n",
    "plt.legend(loc=[1, 0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2c29b1a0fdf9a8",
   "metadata": {},
   "source": "The strategies compared here show robust calibration plots: the effective coverages follow the target coverage levels almost linearly. By contrast, a \"naive\" method (not included in `STRATEGIES` above) underestimates the coverage by giving prediction intervals that are too narrow, because they are built from training data."
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,md"
  },
  "kernelspec": {
   "display_name": ".venv-doc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
